"""Codes for Fig. 2 plot 
Created by Yuechuan Lin 
03-20-2022
Cornell University
"""
import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd 
import matplotlib 
from scipy import signal 
from scipy import ndimage
import matplotlib.ticker as tck
import matplotlib.ticker 
from matplotlib import gridspec
import scipy.io as sio
import h5py
import scipy.signal as ssignal
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm 
from matplotlib.colors import LightSource
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
#import svgutils.transform as sg
import sys
import seaborn as sns 
import pandas as pd 
from statannot import add_stat_annotation
from scipy import stats
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.io import loadmat
from scipy.stats import norm
#print(matplotlib.matplotlib_fname()) # show where the font of matplotlib is 
# Figure size for a full-page, two-column layout: 7.2 in wide x 9.7 in high.

# set font of plot 
# Typography shared by every figure in this script: 8 pt normal-weight
# sans-serif (Helvetica) text rendered through LaTeX, with TrueType (Type 42)
# fonts embedded in PDF and PostScript output.
font = {'weight': 'normal', 'size': 8}
matplotlib.rcParams.update({
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
    'font.weight': font['weight'],
    'font.size': font['size'],
    'text.usetex': True,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'mathtext.default': 'regular',
})

# matplotlib.font_manager._rebuild() refreshes Matplotlib's font cache; it only
# needs to run once after changing system fonts.
# NOTE(review): saveAll is not read anywhere in this script; no figure is saved.
saveAll = True
# Helper used before the 3D scatter comparison of PAAm (P2) and collagen (C3) with normalized G': drop beads whose G' is NaN.
def RemoveNan(PosData,AmechData,FradData,GData):
    """Drop every bead whose G value is NaN from all four per-bead arrays.

    Parameters
    ----------
    PosData : array-like
        One row of 3 coordinates per bead.
    AmechData, FradData, GData : array-like
        One value per bead. Singleton dimensions (e.g. the (N, 1) columns that
        ``loadmat`` returns) are squeezed away first.

    Returns
    -------
    tuple of ndarray
        ``(PosData, AmechData, FradData, GData)`` restricted to the beads whose
        ``GData`` entry is not NaN. Positions are always shaped ``(k, 3)`` and the
        other three are always 1-D of length ``k``, including ``k == 0`` and ``k == 1``.
    """
    # atleast_1d keeps a single-bead input indexable after squeeze().
    GData = np.atleast_1d(np.squeeze(GData))
    nBeads = GData.size
    # Only the first nBeads entries are paired with a G value.
    AmechData = np.atleast_1d(np.squeeze(AmechData))[:nBeads]
    FradData = np.atleast_1d(np.squeeze(FradData))[:nBeads]
    PosData = np.reshape(np.asarray(PosData)[:nBeads], (-1, 3))

    keep = ~np.isnan(GData)
    return PosData[keep], AmechData[keep], FradData[keep], GData[keep]



# import data 
# import collagen C3 data set 
# Collagen sample C3: per-bead fields from the 'Results' struct; RemoveNan
# drops every bead whose Gre value is NaN.
fid_C3 = loadmat("./Fig2_data/CollagenSpatialProfile/ResultsC3.mat")
Pos_C3, Amech_C3, Frad_C3, Gre_C3 = RemoveNan(
    *[fid_C3['Results'][field][0, 0] for field in ('Pos', 'Amech', 'Frad', 'Gre')])

# PAAm 3T2C sample P2: same struct layout, same cleaning.
fid_P2 = loadmat("./Fig2_data/PAAmSpatialProfile/Results.mat")
Pos_P2, Amech_P2, Frad_P2, Gre_P2 = RemoveNan(
    *[fid_P2['Results'][field][0, 0] for field in ('Pos', 'Amech', 'Frad', 'Gre')])

def points3d_oce(PosData,GData,figHandle,ax,NumBeads,TextPos,cmap=cm.hsv,vrange=None,contourPlot=False):
    """Draw a 3-D bead scatter coloured by squared relative deviation from the median G.

    Each bead is coloured by ``((G - median(G)) / median(G))**2``. A horizontal
    colorbar with ticks at the two colour limits is inset near the upper-right
    corner of ``ax``, and an "<NumBeads> beads" label is written at ``TextPos``.

    Parameters
    ----------
    PosData : array
        One row per bead, columns ordered (z, x, y).
    GData : array-like
        Per-bead G values (singleton dimensions are squeezed away).
    figHandle : matplotlib Figure
        Figure containing ``ax``; receives the colorbar and the ``bottom=0.0``
        subplot adjustment.
    ax : 3-D axes
        Axes to draw into.
    NumBeads : int
        Number printed in the bead-count label.
    TextPos : sequence of 3 numbers
        (x, y, z) data coordinates of the bead-count label.
    cmap : colormap
        Colormap for the scatter and the colorbar.
    vrange : None, empty sequence, or [vmin, vmax]
        Colour limits; ``None`` or empty uses the data min/max.
    contourPlot : bool
        Unused; kept so existing calls keep working.

    Returns
    -------
    int
        Always 1 (kept for backward compatibility).

    Raises
    ------
    ValueError
        If ``vrange`` holds neither 0 nor 2 values, or if median(G) is zero.
    """
    GData = np.squeeze(GData)
    GMedian = np.median(GData)
    if GMedian == 0:
        # The relative deviation would divide by zero and colour every bead inf/NaN.
        raise ValueError("median of GData is zero; relative deviation is undefined")
    GData = np.abs((GData - GMedian) / GMedian) ** 2

    if vrange is None or np.size(vrange) == 0:
        vmin = np.amin(GData)
        vmax = np.amax(GData)
    elif np.size(vrange) == 2:
        vmin = vrange[0]
        vmax = vrange[1]
    else:
        raise ValueError("Vrange should have a dimension as [vmin,vmax]")
    Norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)

    # PosData columns are (z, x, y): x and y go on the horizontal axes, z vertical.
    ax.scatter(PosData[:,1],PosData[:,2],PosData[:,0],s=15,c=GData,cmap=cmap,norm=Norm,linewidths=0,alpha=0.7,edgecolors='face')
    ax.set_xlabel(r"$x$ ($\mu$m)",labelpad=-6)
    ax.set_ylabel(r"$y$ ($\mu$m)",labelpad=-7)
    ax.set_zlabel(r"$z$ ($\mu$m)",labelpad=-3)
    ax.tick_params(axis='both', pad=-5)
    ax.tick_params(axis="z",pad=1)
    ax.view_init(elev=15, azim=45)
    # Inset strip (axes-fraction coordinates) that will hold the colorbar.
    axTmp = ax.inset_axes([0.67, 0.8, 0.25, 0.05])
    ax.set_title(r"$\left[\frac{\mathbf{G^{\prime}}-\mathbf{G^{\prime}}_{med}}{\mathbf{G^{\prime}}_{med}}\right]^2$")
    ax.text(TextPos[0],TextPos[1],TextPos[2],str(NumBeads)+" beads")

    # Norm already carries [vmin, vmax], so the mappable needs no extra set_clim.
    ScaIm = cm.ScalarMappable(norm=Norm,cmap=cmap)
    cbar = figHandle.colorbar(ScaIm,cax=axTmp,orientation='horizontal',fraction=0.02,ticks=[vmin, vmax])
    axTmp.xaxis.set_ticks_position("top")
    cbar.ax.xaxis.set_tick_params(color='black',length=0.1,width=0.1)
    cbar.outline.set_edgecolor('white')
    cbar.ax.tick_params(labelsize=10,pad=0)
    # Adjust the figure that owns ax, not whichever figure pyplot considers current.
    figHandle.subplots_adjust(bottom=0.0)

    # Axes3D.w_xaxis/w_yaxis/w_zaxis were removed in Matplotlib 3.8; xaxis/yaxis/zaxis
    # are the same Axis3D objects.
    paneColor = (242/255,241/255,239/255,0.5)
    for axis3d in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis3d.set_pane_color(paneColor)
    return 1

## 3D scatter plot with normalized Gre

# One 3-D scatter figure per sample: collagen C3 first, then PAAm P2. Both use
# the same fixed colour range [0, 0.36] so the two panels are directly comparable.
# NOTE(review): tight_layout/subplots_adjust run on each new figure before its
# axes exist, so they probably have little visible effect — confirm.
figC3 = plt.figure(constrained_layout=False,figsize=(2.6,2.3))
plt.tight_layout(pad=0)
plt.subplots_adjust(wspace=0.5,hspace=0)
axC3 = plt.subplot2grid((1,1),(0,0),rowspan=1,colspan=1,projection='3d')

# Bead count = number of rows left after NaN removal; [-36,45,28] places the "N beads" label.
reC3 = points3d_oce(Pos_C3,Gre_C3,figC3,axC3,np.shape(Pos_C3)[0],[-36,45,28],vrange=[0.0,0.36])

figP2 = plt.figure(constrained_layout=False,figsize=(2.6,2.3))
plt.tight_layout(pad=0)
plt.subplots_adjust(wspace=0.5,hspace=0)
axP2 = plt.subplot2grid((1,1),(0,0),rowspan=1,colspan=1,projection='3d')
reP2 = points3d_oce(Pos_P2,Gre_P2,figP2,axP2,np.shape(Pos_P2)[0],[-18,65,25],vrange=[0.0,0.36])

# Statistic distribution of Gre and R 
# loading data 
fidStat = loadmat("./Fig2_data/ViolinPlot.mat")
GreStat, GimStat, SnamesStat = (
    fidStat['ViolinPlot'][field][0, 0] for field in ('Gre', 'Gim', 'snames'))

# Split the pooled per-bead values by sample name ('P2' and 'C3').
GreStat_P2, GimStat_P2 = (values[SnamesStat == 'P2'] for values in (GreStat, GimStat))
GreStat_C3, GimStat_C3 = (values[SnamesStat == 'C3'] for values in (GreStat, GimStat))


figStat = plt.figure(figsize=(2.3,2.3))
plt.tight_layout(pad=5)
plt.subplots_adjust(wspace=0.1,hspace=0.5)
# Two stacked panels on a 20x20 grid sharing one x axis:
# axS1 (top, labelled "PAAm") and axS2 (bottom, labelled "Collagen").
axS2 = plt.subplot2grid((20, 20), (10, 2), rowspan=9, colspan=16)
axS1 = plt.subplot2grid((20, 20), (1, 2), sharex=axS2, rowspan=9, colspan=16)

# Open frame: hide top/right spines; inward ticks on the bottom and left only.
for axS in (axS1, axS2):
    for spineName in ('top', 'right'):
        axS.spines[spineName].set_visible(False)
    for tickKind, tickLen in (('major', 3), ('minor', 1.5)):
        axS.xaxis.set_tick_params(which=tickKind, size=tickLen, width=0.5, direction='in', top=False, bottom=True)
        axS.yaxis.set_tick_params(which=tickKind, size=tickLen, width=0.5, direction='in', left=True, right=False)


# Fit a normal distribution to the data:
fitColor = '#808080'


def _plot_normal_fit(ax, samples, alpha):
    """Fit a normal distribution to *samples* and draw its pdf on *ax*.

    The pdf is evaluated at 100 points from 0.8*min to 1.1*max of the samples
    (this range widens the data span only for positive data).
    Returns ``(mu, std, x, pdf)``.
    """
    mu, std = norm.fit(samples)
    x = np.linspace(np.amin(samples) * 0.8, np.amax(samples) * 1.1, 100)
    pdf = norm.pdf(x, mu, std)
    ax.plot(x, pdf, fitColor, linewidth=0.5, alpha=alpha)
    return mu, std, x, pdf


# Top panel (axS1): P2; bottom panel (axS2): C3. Gre curves are drawn more opaque than Gim.
mu_GreP2, std_GreP2, xGreP2, pGreP2 = _plot_normal_fit(axS1, GreStat_P2, 0.85)
mu_GimP2, std_GimP2, xGimP2, pGimP2 = _plot_normal_fit(axS1, GimStat_P2, 0.5)
mu_GreC3, std_GreC3, xGreC3, pGreC3 = _plot_normal_fit(axS2, GreStat_C3, 0.85)
mu_GimC3, std_GimC3, xGimC3, pGimC3 = _plot_normal_fit(axS2, GimStat_C3, 0.5)

# Panel colours: index 0 (red) is used for collagen, index 1 (blue) for PAAm,
# here and in later panels.
tmpColor = ["#D81C28","#019BD8"]

# Legend widths are a Gaussian FWHM = factor * sigma. The exact factor is
# 2*sqrt(2*ln 2) ~= 2.3548; 2.35 is kept so the rounded values are unchanged.
fwhmPerSigma = 2.35


def _fwhm_label(symbol, sigma):
    """Return the legend entry '<symbol>, sigma_FWHM = <fwhm> Pa' for a fitted width *sigma*."""
    return symbol + r', $\sigma_{FWHM}= $' + str(np.round(fwhmPerSigma * sigma, 1)) + " Pa"


# Same label format in both panels (the C3 labels previously lacked the space after the comma).
axS1.hist(GreStat_P2,bins=25,density=True,facecolor=tmpColor[1], alpha=1.0,label=_fwhm_label(r'$G^{\prime}$', std_GreP2))
axS1.hist(GimStat_P2,bins=25,density=True,facecolor=tmpColor[1], alpha=0.3,label=_fwhm_label(r'$R$', std_GimP2))
axS2.hist(GreStat_C3,bins=25,density=True,facecolor=tmpColor[0], alpha=1.0,label=_fwhm_label(r'$G^{\prime}$', std_GreC3))
axS2.hist(GimStat_C3,bins=25,density=True,facecolor=tmpColor[0], alpha=0.3,label=_fwhm_label(r'$R$', std_GimC3))
axS1.legend(fontsize=3.5,bbox_to_anchor=(0.55,0.65))
axS2.legend(fontsize=3.5,bbox_to_anchor=(0.58,0.75))

# Plain text: wrapping the word in $...$ typeset it as a product of italic math symbols.
axS2.set_xlabel("Mechanics (Pa)")
# NOTE(review): with density=True the y axis is a probability density (1/Pa), not a count.
axS1.set_ylabel('Frequency')
axS2.set_ylabel('Frequency',labelpad=0)
# The panels share x, so only the bottom panel shows x tick labels.
plt.setp(axS1.get_xticklabels(), visible=False)

# Panel titles placed in data coordinates.
axS1.text(18,0.0144,"PAAm")
axS2.text(18,0.0075,"Collagen")

## Plot correlated mechanics to the presence of fibrin (OCT intensity)
# compounding 
fidCorr = loadmat("./Fig2_data/CollagenCorrelation/ScatteringEffect.mat")
snameCorr, ResultsCorr = (
    fidCorr['ScatteringEffect'][field][0, 0] for field in ('sample', 'results'))

# Row masks: sample id 2 is stored as *_P2, sample id 5 as *_C3.
rowsP2 = np.squeeze(snameCorr == 2)
rowsC3 = np.squeeze(snameCorr == 5)
# Columns read from the results table: 1 -> CumInten, 4 -> Gre, 5 -> Gim.
CumInten_P2, GreCorr_P2, GimCorr_P2 = (ResultsCorr[rowsP2, col] for col in (1, 4, 5))
CumInten_C3, GreCorr_C3, GimCorr_C3 = (ResultsCorr[rowsC3, col] for col in (1, 4, 5))

# Magnitude of the (Gre, Gim) pair per entry.
GCorr_P2 = np.sqrt(GreCorr_P2**2 + GimCorr_P2**2)
GCorr_C3 = np.sqrt(GreCorr_C3**2 + GimCorr_C3**2)
# NOTE(review): none of the values above are used by the plots later in this script.



fidSpe = loadmat("./Fig2_data/collagenCorr_scatterPlotData_220217_All.mat")
# NOTE: SpeStd holds the 'VF' field despite its name (it is plotted with the label "VF").
SpeGre, SpeR, SpeQ95, SpeMean, SpeStd = (
    fidSpe['ScatterPlot'][field][0, 0] for field in ('Gre', 'R', 'Q95', 'Mean', 'VF'))
# Q95 and Mean are divided by 1e6; VF is used as stored.
SpeQ95 = SpeQ95 / 1e6
SpeMean = SpeMean / 1e6

# 4x2 grid of regression panels on a 26x20 subplot grid.
# Left column (axSpeG*): y against G'; right column (axSpeR*): y against R.
# Rows: 0 = Q95, 1 = Mean, 2 = VF, 3 = fiber connectivity.
figSpe = plt.figure(figsize=(2.7,2.7))
plt.tight_layout(pad=2)
plt.subplots_adjust(wspace=0.2,hspace=0.3)
axSpeG2 = plt.subplot2grid((26,20),(13,1),rowspan=6,colspan=9) # VF (loaded into SpeStd)
axSpeG0 = plt.subplot2grid((26,20),(1,1),rowspan=6,colspan=9) # Q95
axSpeG1 = plt.subplot2grid((26,20),(7,1),rowspan=6,colspan=9) # Mean
axSpeG3 = plt.subplot2grid((26,20),(19,1),rowspan=6,colspan=9) # Fiber connectivity

axSpeR2 = plt.subplot2grid((26,20),(13,10),rowspan=6,colspan=9)
axSpeR0 = plt.subplot2grid((26,20),(1,10),rowspan=6,colspan=9)
axSpeR1 = plt.subplot2grid((26,20),(7,10),rowspan=6,colspan=9)
axSpeR3 = plt.subplot2grid((26,20),(19,10),rowspan=6,colspan=9)

# Right column: y axis on the right, left/top spines hidden, inward ticks.
for axS in [axSpeR0,axSpeR1,axSpeR2,axSpeR3]:
    axS.yaxis.set_ticks_position("right")
    axS.spines['top'].set_visible(False)
    axS.spines['left'].set_visible(False)
    axS.xaxis.set_tick_params(which='major', size=3, width=0.5, direction='in', top=False,bottom=True)
    axS.xaxis.set_tick_params(which='minor', size=1.5, width=0.5, direction='in', top=False,bottom=True)
    axS.yaxis.set_tick_params(which='major', size=3, width=0.5, direction='in', left=False,right=True)
    axS.yaxis.set_tick_params(which='minor', size=1.5, width=0.5, direction='in', left=False,right=True)

# Left column: y axis on the left, top/right spines hidden, inward ticks.
for axS in [axSpeG0,axSpeG1,axSpeG2,axSpeG3]:
    axS.spines['top'].set_visible(False)
    axS.spines['right'].set_visible(False)
    axS.xaxis.set_tick_params(which='major', size=3, width=0.5, direction='in', top=False,bottom=True)
    axS.xaxis.set_tick_params(which='minor', size=1.5, width=0.5, direction='in', top=False,bottom=True)
    axS.yaxis.set_tick_params(which='major', size=3, width=0.5, direction='in', left=True,right=False)
    axS.yaxis.set_tick_params(which='minor', size=1.5, width=0.5, direction='in', left=True,right=False)

# Only the bottom row (fiber connectivity) keeps its x tick labels.
for axS in [axSpeG0,axSpeG1,axSpeG2,axSpeR0,axSpeR1,axSpeR2]:
    plt.setp(axS.get_xticklabels(), visible=False)


# fit all dataset
def LinearFit(x,y,ax,color='g',xlabel='',ylabel=''):
    """Scatter *y* against *x* on *ax* with seaborn's robust linear fit and 95 % CI band.

    Parameters
    ----------
    x, y : array-like
        Paired samples; singleton dimensions are squeezed away.
    ax : matplotlib Axes
        Axes to draw on.
    color : colour
        Colour of the points and the fit line.
    xlabel, ylabel : str
        Axis labels (8 pt); they replace seaborn's automatic 'x'/'y' labels.

    Returns
    -------
    int
        Always 1.
    """
    points = pd.DataFrame({"x": np.squeeze(x), "y": np.squeeze(y)})
    fitted = sns.regplot(
        x='x', y='y', data=points, ci=95, robust=True, color=color, ax=ax,
        scatter_kws={"s": 5, "edgecolors": "face", "alpha": 0.9},
        line_kws={"linewidth": 1, "alpha": 0.75},
    )
    # Clear the column-name labels seaborn sets, then apply the requested ones.
    fitted.set_xlabel("")
    fitted.set_ylabel("")
    ax.set_xlabel(xlabel, fontsize=8, labelpad=-2)
    ax.set_ylabel(ylabel, fontsize=8)
    return 1


# Rows 0-2: Q95, Mean and VF against G' (left column, tmpColor[0]) and R (right column, tmpColor[1]).
for yData, axLeft, axRight, rowLabel in ((SpeQ95, axSpeG0, axSpeR0, 'Q95'),
                                         (SpeMean, axSpeG1, axSpeR1, 'Mean'),
                                         (SpeStd, axSpeG2, axSpeR2, "VF")):
    LinearFit(SpeGre, yData, axLeft, color=tmpColor[0], ylabel=rowLabel, xlabel="")
    LinearFit(SpeR, yData, axRight, color=tmpColor[1], xlabel="")

# Row 3: fiber connectivity against G' and R, from 15 hard-coded data points.
fiberG = [346.11703,243.11528,204.83943,279.76581,353.40024,349.99121,373.80295,331.82129,295.55301,436.91718,411.05453,302.3111,397.80554,366.10907,235.40878]
fiberR = [0.40779865,0.45176899,1.1399534,0.46698588,0.35126758,0.64596957,0.37702268,0.50162476,0.46381909,0.31771237,0.23560162,0.50213516,0.36518392,0.44572785,0.84793556]
fiberCon = [4,3,3,3,4,5,5,5,3,7,5,2,4,5,2]
LinearFit(fiberG, fiberCon, axSpeG3, color=tmpColor[0], ylabel='Fiber\nConnectivity', xlabel=r"$G^{\prime}$")
LinearFit(fiberR, fiberCon, axSpeR3, color=tmpColor[1], xlabel=r"$R$")




"""Confocal and PFOCE en-face imaging  
"""
fig2 = plt.figure(constrained_layout=False,figsize=(5,5))
plt.tight_layout(pad=0.0)
plt.subplots_adjust(wspace=0.05,hspace=0.1)

# 3x3 grid of panels (each 2x2 cells of a 6x6 grid), created row by row:
#   axc4-axc6 : confocal images
#   axc7-axc9 : OCT en-face images overlaid with G' scatter
#   axc10-axc12: OCT en-face images overlaid with R scatter
(axc4, axc5, axc6,
 axc7, axc8, axc9,
 axc10, axc11, axc12) = [plt.subplot2grid((6, 6), (row, col), rowspan=2, colspan=2)
                         for row in (0, 2, 4) for col in (0, 2, 4)]

EnfaceData = ['OCE_collagen_EnfaceC1','OCE_collagen_EnfaceC2','OCE_collagen_EnfaceC3']


def _load_enface(name):
    """Load one en-face dataset from ./Fig2_data/<name>.mat (struct stored under key *name*).

    Returns ``(imgData, xAxis, yAxis, Pos, Gre, R)``. ``imgData``, ``xAxis`` and
    ``yAxis`` come from the 'OCTImg' sub-struct; ``Pos``, ``Gre`` and ``R`` come
    from 'Results'. Gre and R are reshaped row-major into 4 columns (this script
    plots column 0).
    """
    record = sio.loadmat('./Fig2_data/' + name + '.mat')[name]
    octImg = record['OCTImg'][0, 0]
    results = record['Results'][0, 0]
    return (octImg['imgData'][0, 0],
            octImg['xAxis'][0, 0],
            octImg['yAxis'][0, 0],
            results['Pos'][0, 0],
            # (-1, 4) replaces the float size/4 arithmetic; same shape for sizes divisible by 4.
            np.reshape(results['Gre'][0, 0], (-1, 4)),
            np.reshape(results['R'][0, 0], (-1, 4)))


# One explicit unpack per sample replaces the loop; its duplicate load branch and
# unreachable fallback were dead code.
(OCEenfaceC1, OCEenfaceC1_x, OCEenfaceC1_y,
 PFOCE_posC1, PFOCE_GreC1, PFOCE_RC1) = _load_enface(EnfaceData[0])
(OCEenfaceC2, OCEenfaceC2_x, OCEenfaceC2_y,
 PFOCE_posC2, PFOCE_GreC2, PFOCE_RC2) = _load_enface(EnfaceData[1])
(OCEenfaceC3, OCEenfaceC3_x, OCEenfaceC3_y,
 PFOCE_posC3, PFOCE_GreC3, PFOCE_RC3) = _load_enface(EnfaceData[2])
print(np.shape(PFOCE_posC2))
OCEcmap = cm.turbo


def _pooled_extrema(*arrays):
    """Return (min of per-array minima, max of per-array maxima) over column 0 of each array."""
    lows = np.asarray([np.amin(values[:, 0]) for values in arrays])
    highs = np.asarray([np.amax(values[:, 0]) for values in arrays])
    return np.amin(lows), np.amax(highs)


# One colour scale per modulus, shared by all three samples so panels are comparable.
# G' spans the pooled range; R is capped at 85 % of its pooled maximum.
greLo, greHi = _pooled_extrema(PFOCE_GreC1, PFOCE_GreC2, PFOCE_GreC3)
OCEnorm = matplotlib.colors.Normalize(vmin=greLo, vmax=greHi)
rLo, rHi = _pooled_extrema(PFOCE_RC1, PFOCE_RC2, PFOCE_RC3)
OCEnorm2 = matplotlib.colors.Normalize(vmin=rLo, vmax=0.85 * rHi)

def OCEenfacePlot(ax,OCTData,OCTData_x,OCTData_y,posData,OCEData,OCEcmapx,OCEnormx,scatterPlot = True,lb=0):
    """Draw one en-face panel: a transposed grayscale image, optionally overlaid with beads.

    With ``scatterPlot=False`` only the image is drawn (normalised by ``OCEnormx``,
    no extent) and the image artist is returned. In that mode OCTData_x,
    OCTData_y, posData, OCEData and OCEcmapx are ignored.

    With ``scatterPlot=True`` the image is drawn with fixed gray limits [15, 85]
    on the extent spanned by OCTData_x / OCTData_y. The view is cropped to one of
    two hard-coded windows (``lb`` truthy selects the first). Beads at
    ``posData[:, 0:2]`` are scattered on top, coloured by OCEData through
    OCEcmapx / OCEnormx, and the scatter artist is returned.

    Tick marks and tick labels are hidden in both modes.
    """
    image = np.transpose(OCTData)

    if not scatterPlot:
        ax.tick_params(axis='both', which='both', length=0)
        plt.setp(ax.get_xticklabels(), visible=False)
        plt.setp(ax.get_yticklabels(), visible=False)
        return ax.imshow(image, cmap=cm.gray, norm=OCEnormx, origin='upper', zorder=0)

    grayNorm = matplotlib.colors.Normalize(vmin=15, vmax=85)
    extent = [np.amin(OCTData_x), np.amax(OCTData_x), np.amax(OCTData_y), np.amin(OCTData_y)]
    plt.setp(ax.get_xticklabels(), visible=False)
    plt.setp(ax.get_yticklabels(), visible=False)
    ax.tick_params(axis='both', which='both', length=0)
    ax.imshow(image, cmap=cm.gray, norm=grayNorm, origin='upper', zorder=0, extent=extent)

    xEnd = OCTData_x[0, -1]
    yEnd = OCTData_y[0, -1]
    if lb:
        ax.set_xlim(-36, xEnd)
        ax.set_ylim(OCTData_y[0, 24], yEnd)
    else:
        ax.set_xlim(-32.2, xEnd)
        ax.set_ylim(18.51, yEnd)
    # zorder=10 keeps the beads above the image.
    return ax.scatter(posData[:, 0], posData[:, 1], s=35, c=OCEData, norm=OCEnormx,
                      cmap=OCEcmapx, linewidths=0, alpha=0.7, edgecolors='face', zorder=10)
# import confocala dataset
# Confocal en-face images of collagen under three preparation conditions.
Confocal_incubator = np.loadtxt("./Fig2_data/Confocal_collagen_enface_incubator.txt")
Confocal_Lu = np.loadtxt("./Fig2_data/Confocal_collagen_enface_multiplesteps_protocol.txt")
Confocal_RT = np.loadtxt("./Fig2_data/Confocal_collagen_enface_room_temperature.txt")


def _suppress_background(image, percentile):
    """Set pixels below *percentile* of *image* to 0.1 (in place), then return image**0.4."""
    image[image < np.percentile(image.flatten(), percentile)] = 0.1
    return image ** 0.4


def _confocal_norm(image):
    """Linear colour scale from the image minimum to 80 % of its maximum."""
    return matplotlib.colors.Normalize(vmin=np.amin(image), vmax=0.8 * np.amax(image))


Confocal_incubator = _suppress_background(Confocal_incubator, 87)
Confocal_Lu = _suppress_background(Confocal_Lu, 80)
Confocal_RT = _suppress_background(Confocal_RT, 88)

# NOTE(review): the three per-image norms are not used later in this script;
# the panels use ConfocalNorm_All.
ConfocalNorm_In = _confocal_norm(Confocal_incubator)
ConfocalNorm_Lu = _confocal_norm(Confocal_Lu)
ConfocalNorm__RT = _confocal_norm(Confocal_RT)
# Shared scale: 1.5x the smallest image minimum up to 0.8x the largest image maximum.
ConfocalNorm_All = matplotlib.colors.Normalize(
    vmin=1.5 * np.amin(np.asarray([np.amin(img) for img in (Confocal_incubator, Confocal_Lu, Confocal_RT)])),
    vmax=0.8 * np.amax(np.asarray([np.amax(img) for img in (Confocal_incubator, Confocal_Lu, Confocal_RT)])))

print("Confocal shape is:{}".format(np.shape(Confocal_incubator)))

# Row 1: confocal images (incubator, room temperature, multiple-steps protocol)
# on the shared ConfocalNorm_All scale. With scatterPlot=False, OCEenfacePlot
# ignores the coordinate/bead arguments, so the C1 values here are placeholders.
OCEenfacePlot(axc4,Confocal_incubator,OCEenfaceC1_x,OCEenfaceC1_y,PFOCE_posC1,PFOCE_GreC1[:,0],OCEcmap,ConfocalNorm_All,scatterPlot=False)
OCEenfacePlot(axc5,Confocal_RT,OCEenfaceC1_x,OCEenfaceC1_y,PFOCE_posC1,PFOCE_GreC1[:,0],OCEcmap,ConfocalNorm_All,scatterPlot=False)
OCEenfacePlot(axc6,Confocal_Lu,OCEenfaceC1_x,OCEenfaceC1_y,PFOCE_posC1,PFOCE_GreC1[:,0],OCEcmap,ConfocalNorm_All,scatterPlot=False)

# Row 2: OCT en-face of samples C1-C3 with beads coloured by column 0 of Gre
# (shared OCEnorm). The first scatter artist is kept for the colorbar.
scaAxc7 = OCEenfacePlot(axc7,OCEenfaceC1,OCEenfaceC1_x,OCEenfaceC1_y,PFOCE_posC1,PFOCE_GreC1[:,0],OCEcmap,OCEnorm,scatterPlot=True)
OCEenfacePlot(axc8,OCEenfaceC2,OCEenfaceC2_x,OCEenfaceC2_y,PFOCE_posC2,PFOCE_GreC2[:,0],OCEcmap,OCEnorm,scatterPlot=True)
OCEenfacePlot(axc9,OCEenfaceC3,OCEenfaceC3_x,OCEenfaceC3_y,PFOCE_posC3,PFOCE_GreC3[:,0],OCEcmap,OCEnorm,scatterPlot=True)

# Row 3: same images with beads coloured by column 0 of R (shared OCEnorm2).
# The first scatter artist is kept for the second colorbar.
scaAxc10 = OCEenfacePlot(axc10,OCEenfaceC1,OCEenfaceC1_x,OCEenfaceC1_y,PFOCE_posC1,PFOCE_RC1[:,0],OCEcmap,OCEnorm2,scatterPlot=True)
OCEenfacePlot(axc11,OCEenfaceC2,OCEenfaceC2_x,OCEenfaceC2_y,PFOCE_posC2,PFOCE_RC2[:,0],OCEcmap,OCEnorm2,scatterPlot=True)
OCEenfacePlot(axc12,OCEenfaceC3,OCEenfaceC3_x,OCEenfaceC3_y,PFOCE_posC3,PFOCE_RC3[:,0],OCEcmap,OCEnorm2,scatterPlot=True)


# get color bar of enface plot 
# Stand-alone small figure holding the colorbars for the en-face scatter panels.
fig5 = plt.figure(constrained_layout=False,figsize=(2.2,1))
#plt.tight_layout(pad=0)
plt.tight_layout()
#plt.margins(0,0)
plt.subplots_adjust(wspace=0.6,hspace=3.2)

# Colorbar axes: a strip 4 rows high and 15 columns wide on a 16x17 grid.
axcbar1 = plt.subplot2grid((16,17),(4,1),rowspan=4,colspan=15)
#axcbar2 = plt.subplot2grid((16,15),(8,0),rowspan=4,colspan=15)
# NOTE(review): cbarAxis is never used afterwards in this script.
cbarAxis = []


# End ticks for the two colorbars, taken from the norms the scatter panels use.
# OCEnorm2 caps the R scale at 0.85x the pooled maximum. A tick at the raw
# maximum would lie outside the colorbar and would not be drawn, leaving the R
# bar without its upper label.
cbarRange1 = [OCEnorm.vmin, OCEnorm.vmax]
cbarRange2 = [OCEnorm2.vmin, OCEnorm2.vmax]


# G' colorbar (from the axc7 scatter) drawn into axcbar1, ticked at cbarRange1.
# NOTE(review): fraction/aspect are layout options for auto-created colorbar
# axes and presumably have no effect when cax is given.
cbar = fig5.colorbar(scaAxc7,cax=axcbar1,aspect=80,orientation='horizontal',fraction=0.01,format='%.1f',ticks=cbarRange1)
# Twin axes at the same position receive the R colorbar (from the axc10 scatter).
# Both scatters use OCEcmap, so the overlaid bars show one gradient with the G'
# scale read on top and the R scale underneath.
axcbar11 = axcbar1.twiny()
cbar2 = fig5.colorbar(scaAxc10,cax=axcbar11,aspect=80,orientation='horizontal',fraction=0.01,format='%.1f',ticks=cbarRange2)
axcbar11.xaxis.set_ticks_position("bottom")
axcbar1.xaxis.set_ticks_position("top")
# Four minor intervals between major ticks on both scales.
axcbar1.xaxis.set_minor_locator(AutoMinorLocator(4))
axcbar11.xaxis.set_minor_locator(AutoMinorLocator(4))
# Edge each colour patch in its own colour (presumably to hide seams between
# patches in vector output).
cbar.solids.set_edgecolor("face")
cbar2.solids.set_edgecolor("face")


#plot FWHM statistic distribution 
### Now Bar plot of FWHM of all samples
fig3 = plt.figure(figsize=(3,3))
ax3 = plt.subplot2grid((7,7),(1,1),rowspan=5,colspan=5)
BarData_dict = sio.loadmat("./Fig2_data/GreGimDistribution/BarPlot_FWHM.mat")
BarData = BarData_dict['BarPlot_FWHM']
# One row per sample (P1, P2, C1, C2, C3); columns: [G' FWHM, G'' FWHM].
BarFWHM = np.reshape(BarData['FWHM'][0,0],(5,2))

# Wide Sample x Modulus table built directly. The index/column names 'Sample' and
# 'Modulus' become the bar chart's x label and legend title.
pivot_df = pd.DataFrame(
    BarFWHM,
    index=pd.Index(['P1', 'P2', 'C1', 'C2', 'C3'], name='Sample'),
    columns=pd.Index([r'$G^{\prime}$', r'$G^{\prime\prime}$'], name='Modulus'),
)

# Stacked bars: G' (tmpColor[0]) at the bottom, G'' (tmpColor[1]) on top.
pivot_df.loc[:, [r'$G^{\prime}$', r'$G^{\prime\prime}$']].plot.bar(stacked=True, color=tmpColor, ax=ax3)
ax3.set_ylabel(r'Dist. FWHM (Pa)')

plt.show()

